#!/usr/bin/env python
# -*- coding:utf-8 -*-

import ctypes
from dapi import wintype
import pefile
import capstone
from lib import settings
from lib.utils import loader
import os
from lib.utils import printer

class debugger():

    def __init__(self,hmodule:dict):
        """Create a debugger that is not yet attached to any process.

        :param hmodule: maps handler keys to handler script file names that are
            resolved relative to ``settings.LIB_HANDLE_PATH``; ``'bp'`` is used
            by debug_handler_breakpoint() and ``'st'`` by debug_single_step().
        """
        self.debugger_active = None    # run64() keeps looping while this == True
        self.pid = None                # debuggee process ID (set by load_64_exe)
        self.h_process = None          # process handle used for memory reads/writes
        self.h_thread = None           # handle of the thread that reported the last event
        self.context = None            # CONTEXT64 of that thread, or False if it could not be read
        self.breakpoints = {}          # address -> (address, original byte) for INT3 breakpoints
        # Not read in this class; presumably consumed by the handler scripts.
        self.first_breakpoints = True
        self.exception = None          # ExceptionCode of the last exception debug event
        self.exception_address = None  # ExceptionAddress of the last exception debug event
        self.ep_address = None         # address of the initial software breakpoint
        # general_register_info64() refreshes these two; create them up front so
        # reading them before the first register dump does not raise AttributeError.
        self.rsp_address = None
        self.rip_address = None
        self.process_information = wintype.PROCESS_INFORMATION()
        self.hmodule = hmodule
    
    def load_64_exe(self,path_to_exe,bAddr=None):
        """Launch *path_to_exe* as a debuggee and arm the initial breakpoint.

        :param path_to_exe: executable path, passed to CreateProcessW as the
            application name (no command line is supplied).
        :param bAddr: breakpoint address, either as a hex string (``"140001000"``
            or ``"0x140001000"``) or as an int. When None, the PE entry point
            computed by set_ep_address() is used.

        On success ``self.pid`` and ``self.h_process`` are filled in, a software
        breakpoint is planted at ``self.ep_address`` and ``self.debugger_active``
        becomes True. On failure a warning with the Win32 error code is printed.
        """
        if bAddr is None:
            # NOTE(review): this is ImageBase + entry RVA, i.e. it assumes the
            # image is loaded at its preferred base (no ASLR relocation).
            self.set_ep_address(path_to_exe)
        elif isinstance(bAddr, int):
            # int(x, 16) only accepts strings; take numeric addresses as-is.
            self.ep_address = bAddr
        else:
            self.ep_address = int(bAddr,16)

        startupinfo = wintype.STARTUPINFO()
        startupinfo.dwFlags = 0x01      # STARTF_USESHOWWINDOW: honour wShowWindow
        startupinfo.wShowWindow = 0x0   # SW_HIDE: keep the debuggee's window hidden
        startupinfo.cb = ctypes.sizeof(startupinfo)

        kernel32 = ctypes.windll.kernel32
        created = kernel32.CreateProcessW(
            path_to_exe,                # lpApplicationName
            None,                       # lpCommandLine
            None,                       # lpProcessAttributes
            None,                       # lpThreadAttributes
            None,                       # bInheritHandles (FALSE)
            wintype.DEBUG_PROCESS,      # dwCreationFlags: debug this process
            None,                       # lpEnvironment
            None,                       # lpCurrentDirectory
            ctypes.byref(startupinfo),
            ctypes.byref(self.process_information))
        if not created:
            printer.warn("load_64_exe错误: 0x%08x." % kernel32.GetLastError())
            return

        printer.info("成功加载程序!")
        printer.info("程序进程ID[PID]: %d" % self.process_information.dwProcessId)
        self.pid = self.process_information.dwProcessId
        # If OpenProcess fails, fall back to the PROCESS_ALL_ACCESS handle that
        # CreateProcessW already returned so breakpoints can still be written.
        self.h_process = (self.open_process(self.process_information.dwProcessId)
                          or self.process_information.hProcess)
        if self.ep_address:
            if self.set_software_breakpoint(self.ep_address):
                printer.info("断点: 0x%08x" % self.ep_address)
            else:
                printer.warn("断点设置错误!")

        self.debugger_active = True
    def open_process(self,pid):
        """Open process *pid* with PROCESS_ALL_ACCESS and return the raw handle.

        The result is passed through unchanged from kernel32.OpenProcess; the
        Win32 API reports failure as NULL, which ctypes returns as 0.
        """
        kernel32 = ctypes.windll.kernel32
        return kernel32.OpenProcess(wintype.PROCESS_ALL_ACCESS, False, pid)
    def open_thread64(self,tid):
        """Open thread *tid* with THREAD_ALL_ACCESS64.

        :returns: the thread handle, or False (after printing a warning) when
            OpenThread fails.
        """
        h_thread = ctypes.windll.kernel32.OpenThread(wintype.THREAD_ALL_ACCESS64,0,tid)
        # OpenThread reports failure with NULL. With ctypes' default int restype
        # that arrives as 0 (never None), so test truthiness, not identity.
        if h_thread:
            return h_thread
        printer.warn("未成功获取当前线程句柄!")
        return False
    def get_thread_context64(self,tid=None,h_thread=None):
        """Read the full 64-bit register context of a thread.

        :param tid: ID of the thread to open when *h_thread* is not given; the
            newly opened handle is stored in ``self.h_thread``.
        :param h_thread: an already-open thread handle to read instead.
        :returns: a populated ``wintype.CONTEXT64``, or False on failure.
        """
        context = wintype.CONTEXT64()
        context.ContextFlags = wintype.CONTEXT_FULL_64
        if h_thread is None:
            self.h_thread = self.open_thread64(tid)
            h_thread = self.h_thread
        if not h_thread:
            # open_thread64() (or the caller) had no usable handle; nothing to read.
            return False

        kernel32 = ctypes.windll.kernel32
        gtc = kernel32.GetThreadContext
        gtc.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
        gtc.restype = ctypes.c_int   # BOOL (the original misspelled this as `restypes`)
        # NOTE(review): on x64 GetThreadContext requires a 16-byte aligned CONTEXT;
        # confirm wintype.CONTEXT64 guarantees that alignment.
        kernel32.SuspendThread(h_thread)
        try:
            ok = gtc(h_thread, ctypes.byref(context))
            # Capture the error before ResumeThread can overwrite it.
            error = 0 if ok else kernel32.GetLastError()
        finally:
            kernel32.ResumeThread(h_thread)
        if ok:
            return context
        printer.warn("get_thread_context64错误: %d" % error)
        return False
    def enumerate_threads(self):
        """Return the IDs of all threads whose owner process is ``self.pid``.

        :returns: a list of thread IDs, or False if the toolhelp snapshot could
            not be created.
        """
        kernel32 = ctypes.windll.kernel32
        snapshot = kernel32.CreateToolhelp32Snapshot(
            wintype.TH32CS_SNAPPROCESS | wintype.TH32CS_SNAPTHREAD, 0
        )
        # Failure is INVALID_HANDLE_VALUE (all bits set), not None. It shows up as
        # -1 with ctypes' default int restype, or as the unsigned pointer value
        # if a c_void_p restype has been configured elsewhere.
        if snapshot in (None, 0, -1, ctypes.c_void_p(-1).value):
            return False

        thread_entry = wintype.THREADENTRY32()
        thread_entry.dwSize = ctypes.sizeof(thread_entry)
        thread_list = []
        try:
            success = kernel32.Thread32First(snapshot, ctypes.byref(thread_entry))
            while success:
                # A thread snapshot covers the whole system, so filter by owner.
                if thread_entry.th32OwnerProcessID == self.pid:
                    thread_list.append(thread_entry.th32ThreadID)
                success = kernel32.Thread32Next(snapshot, ctypes.byref(thread_entry))
        finally:
            kernel32.CloseHandle(snapshot)
        return thread_list
    def general_register_info64(self,ctx):
        """Print the general-purpose registers of *ctx*, up to three per line.

        Each value is printed as 0x followed by 16 zero-padded hex digits.
        ``ctx.Rsp`` and ``ctx.Rip`` are also stored in ``self.rsp_address``
        and ``self.rip_address``.
        """
        printer.info("寄存器信息:")
        layout = (
            ("Rax", "Rcx", "Rdx"),
            ("Rbx", "Rsp", "Rbp"),
            ("Rsi", "Rdi", "Rip"),
            ("R8", "R9", "R10"),
            ("R11", "R12", "R13"),
            ("R14", "R15"),
        )
        rows = [
            "   ".join("%s: 0x%016x" % (reg, getattr(ctx, reg)) for reg in group)
            for group in layout
        ]
        self.rsp_address = ctx.Rsp
        self.rip_address = ctx.Rip
        for row in rows:
            printer.test(row)
    def flags_register_info64(self,ctx):
        """Print the EFLAGS bits of *ctx* as short flag mnemonics.

        A warning is printed instead when *ctx* is missing or falsy (e.g. the
        False returned by get_thread_context64() on failure).
        """
        printer.info("标志位信息:")
        flags = ctx.EFlags if ctx else None
        if flags is None:
            printer.warn("标志寄存器为空")
        else:
            printer.test(self._format_eflags(flags))

    @staticmethod
    def _format_eflags(flags):
        """Render an EFLAGS value as ``iopl=N`` followed by one mnemonic per flag.

        Example: 0x246 -> ``"iopl=0 no up ei pl zr na pe nc"``.
        """
        # (bit mask, mnemonic when set, mnemonic when clear)
        # NOTE(review): WinDbg spells "no overflow" as "nv"; this output keeps "no".
        table = (
            (0x800, ' ov', ' no'),  # OF: overflow / no overflow
            (0x400, ' dn', ' up'),  # DF: direction down / up
            (0x200, ' ei', ' di'),  # IF: interrupts enabled / disabled
            (0x80,  ' ng', ' pl'),  # SF: negative / positive
            (0x40,  ' zr', ' nz'),  # ZF: zero / not zero
            (0x10,  ' ac', ' na'),  # AF: auxiliary carry / no auxiliary carry
            (0x4,   ' pe', ' po'),  # PF: parity even / parity odd
            (0x1,   ' cy', ' nc'),  # CF: carry / no carry
        )
        # Bits 12-13 hold the I/O privilege level (ring 0-3).
        prefix = 'iopl=%1d' % ((flags & 0x3000) >> 12)
        return prefix + ''.join(on if flags & mask else off for mask, on, off in table)
    def segment_register_info64(self,ctx):
        """Print the six segment selector registers of *ctx* in two rows of three.

        Values are printed as 0x followed by 16 zero-padded hex digits.
        """
        printer.info("段寄存器信息:")
        groups = (("SegCs", "SegSs", "SegDs"), ("SegFs", "SegGs", "SegEs"))
        rendered = [
            "   ".join("%s: 0x%016x" % (seg, getattr(ctx, seg)) for seg in group)
            for group in groups
        ]
        for line in rendered:
            printer.test(line)
    def eflags_register(self,ctx):
        """Print the raw EFlags value of *ctx* as 0x plus 16 zero-padded hex digits."""
        printer.info("标志寄存器信息:")
        raw_flags = ctx.EFlags
        printer.test("EFlags: 0x%016x" % raw_flags)
    def get_debug_event64(self):
        """Wait for one debug event, dispatch it, then continue the debuggee.

        Blocks in WaitForDebugEvent with an INFINITE timeout. For every event it
        opens the reporting thread (stored in ``self.h_thread``) and captures its
        registers into ``self.context``. Exception events record
        ``self.exception`` and ``self.exception_address``. Breakpoint and
        single-step exceptions, including the STATUS_WX86_* variants, are
        delegated to the handler scripts, whose return value becomes the
        ContinueDebugEvent status. EXIT_PROCESS_DEBUG_EVENT calls ``self.stop()``.
        The event is always continued with ``continue_status``, which is
        DBG_CONTINUE unless a handler returned something else. If
        WaitForDebugEvent fails, a warning with the Win32 error code is printed.
        """
        debug_event = wintype.DEBUG_EVENT()
        continue_status = wintype.DBG_CONTINUE
        if ctypes.windll.kernel32.WaitForDebugEvent(ctypes.byref(debug_event),wintype.INFINITE):
            # NOTE(review): a new thread handle is opened for every event and the
            # previous self.h_thread is never closed, so handles accumulate.
            self.h_thread = self.open_thread64(debug_event.dwThreadId)
            self.context = self.get_thread_context64(h_thread=self.h_thread)
            printer.info("调试代码: %d 线程ID: %d" % (debug_event.dwDebugEventCode, debug_event.dwThreadId))
            if debug_event.dwDebugEventCode ==wintype.EXCEPTION_DEBUG_EVENT:
                self.exception = debug_event.u.Exception.ExceptionRecord.ExceptionCode
                self.exception_address = debug_event.u.Exception.ExceptionRecord.ExceptionAddress
                printer.info("调试事件[EXCEPTION_DEBUG_EVENT]发生的内存地址: 0x%08x" % (self.exception_address))
                # NOTE(review): exceptions without a handler (access violation, guard
                # page, anything not listed) keep DBG_CONTINUE, which tells Windows
                # the exception was handled. For an access violation the faulting
                # instruction will most likely just fault again;
                # DBG_EXCEPTION_NOT_HANDLED would pass it on to the debuggee.
                if self.exception == wintype.EXCEPTION_ACCESS_VIOLATION:
                    printer.info("触发调试的异常:EXCEPTION_ACCESS_VIOLATION[无效访问]")
                elif self.exception == wintype.EXCEPTION_BREAKPOINT:
                    # Also fires for the loader's initial breakpoint that Windows raises
                    # in a freshly created debuggee, not only for our INT3s.
                    printer.info("触发调试的异常:EXCEPTION_BREAKPOINT[断点]")
                    continue_status = self.debug_handler_breakpoint(64)
                elif self.exception == wintype.EXCEPTION_SINGLE_STEP:
                    printer.info("触发调试的异常:EXCEPTION_SINGLE_STEP[单步调试]")
                    continue_status = self.debug_single_step(64)
                elif self.exception == wintype.EXCEPTION_GUARD_PAGE:
                    printer.info("触发调试的异常:EXCEPTION_GUARD_PAGE[内存页访问]")
                elif self.exception == wintype.STATUS_WX86_BREAKPOINT:
                    # WOW64 (32-bit code) breakpoint; the message reuses the native label.
                    printer.info("触发调试的异常:EXCEPTION_BREAKPOINT[断点]")
                    continue_status = self.debug_handler_breakpoint(64)
                elif self.exception == wintype.STATUS_WX86_SINGLE_STEP:
                    printer.info("触发调试的异常:STATUS_WX86_SINGLE_STEP[单步调试]")
                    continue_status = self.debug_single_step(64)
            if debug_event.dwDebugEventCode == wintype.EXIT_PROCESS_DEBUG_EVENT:
                if self.context:
                    printer.info("Rip: 0x%016x" % self.context.Rip)
                # stop() detaches and clears debugger_active, ending run64()'s loop.
                self.stop()
            ctypes.windll.kernel32.ContinueDebugEvent(debug_event.dwProcessId, debug_event.dwThreadId, continue_status)
        else:
            printer.warn("[*] get_debug_event64错误: %d" % ctypes.windll.kernel32.GetLastError())
    def run64(self):
        """Pump debug events for as long as ``self.debugger_active`` equals True.

        Before each wait, the Rip of the most recently captured context, if
        any, is printed.
        """
        while True:
            # Compare by equality rather than truthiness, exactly as before.
            if not (self.debugger_active == True):
                return
            last_context = self.context
            if last_context:
                printer.info("Rip: 0x%016x" % last_context.Rip)
            self.get_debug_event64()
    def set_ep_address(self,path):
        """Set ``self.ep_address`` to the entry point of the PE file at *path*.

        The value is ``ImageBase + AddressOfEntryPoint`` from the optional
        header, i.e. the entry point assuming the image is mapped at its
        preferred base address. ASLR relocation is not accounted for.
        """
        pe = pefile.PE(path,fast_load=True)
        try:
            header = pe.OPTIONAL_HEADER
            self.ep_address = header.ImageBase + header.AddressOfEntryPoint
        finally:
            # Release pefile's mapping of the file now rather than whenever the
            # garbage collector gets around to the PE object.
            pe.close()
    def set_software_breakpoint(self,address):
        """Plant an INT3 (0xCC) software breakpoint at *address* in the debuggee.

        The original byte is recorded in ``self.breakpoints[address]`` as
        ``(address, original_byte)`` so it can be restored later.

        :returns: True if a breakpoint is set at *address* (newly or already),
            False if the original byte could not be read or 0xCC could not be
            written.
        """
        printer.info("设置软件断点地址: 0x%08x" % address)
        if address in self.breakpoints:
            # Already armed: report success instead of an implicit None.
            return True
        original_byte = self.read_process_memory(address,1)
        if not original_byte:
            # False (read failure) or an empty read: nothing safe to restore later.
            return False
        # 0xCC is the one-byte INT 3 instruction used for debugger breakpoints.
        if not self.write_process_memory(address,b"\xCC"):
            return False
        # Code was modified in another process; make sure no stale copy runs.
        ctypes.windll.kernel32.FlushInstructionCache(
            self.h_process, ctypes.c_void_p(address), 1)
        self.breakpoints[address] = (address,original_byte)
        return True
    def read_process_memory(self,address,length):
        """Read *length* bytes at *address* from the debuggee (``self.h_process``).

        :returns: the bytes read, or False if ReadProcessMemory fails.
        """
        read_buf = ctypes.create_string_buffer(length)
        # lpNumberOfBytesRead is a SIZE_T*; SIZE_T is pointer-sized, so a 32-bit
        # c_ulong would be overrun on x64.
        count = ctypes.c_size_t(0)
        ok = ctypes.windll.kernel32.ReadProcessMemory(
            self.h_process,
            ctypes.c_void_p(address),  # a bare int is converted to a 32-bit C int and overflows for x64 addresses
            read_buf,
            length,                    # never request more than the buffer can hold
            ctypes.byref(count))
        if not ok:
            return False
        return read_buf.raw[:count.value]
    def write_process_memory(self,address,data):
        """Write the bytes *data* at *address* in the debuggee (``self.h_process``).

        :returns: True if WriteProcessMemory succeeded, otherwise False.
        """
        # lpNumberOfBytesWritten is a SIZE_T*, which is 8 bytes on x64.
        count = ctypes.c_size_t(0)
        length = len(data)
        c_data = ctypes.c_char_p(data)
        ok = ctypes.windll.kernel32.WriteProcessMemory(
            self.h_process,
            ctypes.c_void_p(address),  # keep 64-bit addresses intact
            c_data,
            length,
            ctypes.byref(count))
        return bool(ok)
    def stop(self):
        """Detach from the debuggee and release the handles this object holds.

        On success ``self.debugger_active`` becomes False (which ends run64()),
        the thread/process handles returned by CreateProcessW are closed, and so
        are ``self.h_process`` / ``self.h_thread``, which are then reset to None.
        On failure a warning with the Win32 error code is printed and nothing
        is closed.
        """
        kernel32 = ctypes.windll.kernel32
        if not kernel32.DebugActiveProcessStop(self.pid):
            printer.warn("stop错误: %d" % kernel32.GetLastError())
            return
        self.debugger_active = False
        created = (self.process_information.hThread, self.process_information.hProcess)
        kernel32.CloseHandle(self.process_information.hThread)
        kernel32.CloseHandle(self.process_information.hProcess)
        # Handles from OpenProcess/OpenThread were previously leaked. Close them
        # too, unless they alias a CreateProcessW handle that was just closed.
        for attr in ("h_process", "h_thread"):
            handle = getattr(self, attr)
            if handle and handle not in created:
                kernel32.CloseHandle(handle)
            setattr(self, attr, None)
        printer.warn("动态调试结束...")
    def debug_handler_breakpoint(self,flag):
        """Run the breakpoint handler script and return its continue status.

        The script named by ``self.hmodule['bp']`` under
        ``settings.LIB_HANDLE_PATH`` is re-read and loaded via
        ``loader.load_module`` on every call, and ``mod.run(self, flag)`` is
        invoked. *flag* is the bitness passed by the caller (get_debug_event64
        passes 64). The return value is used as the ContinueDebugEvent status.
        """
        fpath = os.path.join(settings.LIB_HANDLE_PATH,self.hmodule['bp'])
        with open(fpath,"r",encoding="utf-8") as f:
            source = f.read()
        # NOTE(review): load_module presumably executes the file's contents as
        # Python; hmodule must only ever point at trusted scripts.
        mod = loader.load_module(source)
        continue_status = mod.run(self,flag)
        # Report through the shared printer rather than a stray print().
        printer.info("continue_status: %s" % hex(continue_status))
        return continue_status
    def debug_single_step(self,flag):
        """Run the single-step handler script and return its continue status.

        The script named by ``self.hmodule['st']`` under
        ``settings.LIB_HANDLE_PATH`` is read and loaded via
        ``loader.load_module`` on every call. Its ``run(self, flag)`` result is
        returned unchanged; *flag* is the bitness passed by the caller.
        """
        base_dir = settings.LIB_HANDLE_PATH
        script_name = self.hmodule['st']
        with open(os.path.join(base_dir, script_name), "r", encoding="utf-8") as script:
            handler = loader.load_module(script.read())
            return handler.run(self, flag)